{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2 style=\"color: Orange;\">简介</h2>\n",
    "\n",
    "此脚本主要介绍深度算法中常用的向量、列表操作，包括 Indexing，Reorganization, List operation, 以及 torch 基础模块。向量主要有两种类型，numpy.array 和 torch.tensor, 对两类向量进行操作的 API 都有对应的版本，但在参数传递上存在着不同。\n",
    "\n",
    "*参考资料：*\n",
    "- [利用 Python 进行数据分析 (Wes Mckinney)]()\n",
    "- [一文读懂 Pytorch 中的 Tensor View 机制](https://cloud.tencent.com/developer/article/1942122)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.2\n",
      "2.6.0\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "print(np.__version__)\n",
    "print(torch.__version__)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h3 style=\"color: Salmon;\">Indexing</h3>\n",
    "数组或向量索引可以选中数据中的子集或某个单元数组。在深度网络中通常为多维数组的情况，并且在索引时通常与其它 api 结合使用。使用场景包括从批数据中获取目标实例的特征、从结果中获取目标图像的概率、选取，替换或删除目标特征等。\n",
    "\n",
    "在使用不同索引时，要注意是否在相同的内存地址操作，还是在拷贝后数组上操作。并且要注意返回值是数值还是索引。"
   ]
  },
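  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The view-versus-copy distinction mentioned above can be checked directly. The cell below is a minimal sketch: `base`, `sliced`, `masked`, and `fancy` are illustrative names that are not used elsewhere in this notebook, and `np.shares_memory` reports whether two arrays overlap in memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Illustrative array (assumed name, not part of the original notebook)\n",
    "base = np.arange(6)\n",
    "\n",
    "sliced = base[1:4]         # basic slicing returns a view\n",
    "masked = base[base > 2]    # boolean indexing returns a copy\n",
    "fancy = base[[0, 2, 4]]    # fancy indexing returns a copy\n",
    "\n",
    "print(np.shares_memory(base, sliced))  # True\n",
    "print(np.shares_memory(base, masked))  # False\n",
    "print(np.shares_memory(base, fancy))   # False\n",
    "\n",
    "sliced[0] = 100            # writes through to base\n",
    "masked[0] = -1             # leaves base untouched\n",
    "print(base)                # [  0 100   2   3   4   5]"
   ]
  },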
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">初始化数组</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2, 6)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 在 numpy 和 torch 的数组与向量中存在着视图的概念，当不显示地复制向量时，一切的操作都与相当于在原视图上操作。\n",
    "# 例如对切片的数值修改后，原向量的数值同样会被更改\n",
    "inps = np.array([[1, 2, 3, 4, 5, 6], \n",
    "                 [7, 8, 9, 10, 11, 12]], dtype=np.int8)\n",
    "\n",
    "inps.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.int8(8)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps[1, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.int8(6)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps[0, -1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([2, 8], dtype=int8)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps[:, 1]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">切片索引</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1, 2],\n",
       "       [7, 8]], dtype=int8)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps[:, 0:2] # 这里要注意索引包含起始不包含最后"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 3, 5], dtype=int8)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps[0, 0::2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 3, 5], dtype=int8)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps[0, 0:-1:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[6, 4, 2]], dtype=int8)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps[:1, -1::-2]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">* 布尔索引</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 1.87212143,  0.08087614,  1.39863018,  1.27784114, -0.52723119,\n",
       "        -0.61309439],\n",
       "       [-1.9771694 , -0.31594087,  0.13091146,  0.01012644,  0.245643  ,\n",
       "        -0.10345722]])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 布尔操作总是返回拷贝值\n",
    "rand_inps = np.random.randn(2, 6)\n",
    "inps = rand_inps.copy()\n",
    "rand_inps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[1., 0., 1., 1., 0., 0.],\n",
       "       [0., 0., 0., 0., 0., 0.]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 这是一个简单的将分割概率转换为 0-1 分割标签的方法\n",
    "rand_inps[rand_inps > 0.5] = 1\n",
    "rand_inps[rand_inps <= 0.5] = 0\n",
    "\n",
    "rand_inps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ True, False,  True,  True, False, False],\n",
       "       [False, False, False, False, False, False]])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rand_inps = np.bool_(rand_inps)\n",
    "rand_inps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 1.87212143,  1.39863018,  1.27784114],\n",
       "       [-1.9771694 ,  0.13091146,  0.01012644]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps[:, rand_inps[0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.08087614, -0.52723119, -0.61309439],\n",
       "       [-0.31594087,  0.245643  , -0.10345722]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 可以使用的逻辑运算包括 取反：～、  和：&、  或：｜、  不等：！=\n",
    "inps[:, ~rand_inps[0]]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">神奇索引</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[ 0,  1,  2,  3,  4,  5],\n",
       "        [ 6,  7,  8,  9, 10, 11]],\n",
       "\n",
       "       [[12, 13, 14, 15, 16, 17],\n",
       "        [18, 19, 20, 21, 22, 23]]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 神奇索引总是返回拷贝值\n",
    "inps = np.arange(24).reshape(2, 2, 6)\n",
    "inps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.int64(11)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps[0, 1, 5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 7,  9, 11])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps[0, 1, [1, 3, 5]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 7,  9, 11],\n",
       "       [19, 21, 23]])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps[:, [1], [1, 3, 5]]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">where</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "        1]),\n",
       " array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,\n",
       "        1]),\n",
       " array([1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4,\n",
       "        5]))"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 当传入布尔向量时，返回原数据格式索引\n",
    "index = np.where(inps > 0)\n",
    "index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[0, 0, 0, 0, 0, 0],\n",
       "        [1, 1, 1, 1, 1, 1]],\n",
       "\n",
       "       [[1, 1, 1, 1, 1, 1],\n",
       "        [1, 1, 1, 1, 1, 1]]])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 传入条件时，对元素进行操纵，条件为真时为 x，否则为 y\n",
    "np.where(inps > 5, 1, 0)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">gather</span>\n",
    "\n",
    "out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0\n",
    "\n",
    "out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1\n",
    "\n",
    "out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2\n",
    "\n",
    "一个班级调整座位，新的座位表可能每个小朋友都有新位置，也允许小朋友转出，但不允许新的小朋友转入\n",
    "\n",
    "通过索引告诉目标向量在哪些位置应该是向量中哪个位置元素的值\n",
    "\n",
    "gather 与 scatter 的重要性在于，相同的操作如果采用自定义的循环的赋值或替换函数，会消耗大量的计算资源与时间"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 1],\n",
       "        [4, 3]])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 将 t 向量 00，01，10，11 位置的值 替换为 00，00，11，10\n",
    "# dim = 1 指替换向量的哪一维度，索引代表将维度中的值替换为什么\n",
    "t = torch.tensor([[1, 2], [3, 4]])\n",
    "torch.gather(input=t, dim=1, index=torch.tensor([[0, 0], [1, 0]]))"
   ]
  },
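  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One common use of gather is reading the probability of the target class for every sample in a batch. The cell below is a minimal sketch under assumed inputs: `probs` (shape `(batch, classes)`) and `labels` (shape `(batch,)`) are made-up illustrative tensors, and the last line compares the result with plain integer-array indexing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Illustrative per-class probabilities and target labels (assumed values)\n",
    "probs = torch.tensor([[0.1, 0.7, 0.2],\n",
    "                      [0.8, 0.1, 0.1]])\n",
    "labels = torch.tensor([1, 0])\n",
    "\n",
    "# index must have the same number of dimensions as the input, hence unsqueeze(1)\n",
    "picked = torch.gather(probs, dim=1, index=labels.unsqueeze(1)).squeeze(1)\n",
    "\n",
    "# The same selection written with integer-array indexing\n",
    "same = probs[torch.arange(probs.size(0)), labels]\n",
    "\n",
    "print(picked)                     # tensor([0.7000, 0.8000])\n",
    "print(torch.equal(picked, same))  # True"
   ]
  },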
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">scatter</span>\n",
    "\n",
    "self[index[i][j][k]][j][k] = src[i][j][k]   if dim == 0\n",
    "\n",
    "self[i][index[i][j][k]][k] = src[i][j][k]   if dim == 1\n",
    "\n",
    "self[i][j][index[i][j][k]] = src[i][j][k]   if dim == 2\n",
    "\n",
    "一个班级调整座位，有新的小朋友转入，对应位置小朋友转出\n",
    "\n",
    "通过索引告诉目标向量在哪些位置应该是给定向量中哪些位置的值"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1,  2,  3,  4,  5],\n",
       "        [ 6,  7,  8,  9, 10]])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "src = torch.arange(1, 11).reshape((2, 5))\n",
    "src"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 0, 0, 4, 0],\n",
       "        [0, 2, 0, 0, 0],\n",
       "        [0, 0, 3, 0, 0]])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "index = torch.tensor([[0, 1, 2, 0]])\n",
    "torch.zeros(3, 5, dtype=src.dtype).scatter_(dim=0, index=index, src=src)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 2, 3, 0, 0],\n",
       "        [6, 7, 0, 0, 8],\n",
       "        [0, 0, 0, 0, 0]])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "index = torch.tensor([[0, 1, 2], [0, 1, 4]])\n",
    "torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)"
   ]
  },
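  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A typical application of scatter is one-hot encoding of class labels. The cell below is a small sketch: `labels` and `one_hot` are illustrative names, three classes are assumed, and the scalar `value=1` overload of `scatter_` is used instead of a `src` tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Illustrative class labels for three samples (assumed values)\n",
    "labels = torch.tensor([2, 0, 1])\n",
    "\n",
    "# Write a 1 into column labels[i] of row i; every other entry stays 0\n",
    "one_hot = torch.zeros(3, 3, dtype=torch.long).scatter_(\n",
    "    dim=1, index=labels.unsqueeze(1), value=1\n",
    ")\n",
    "\n",
    "print(one_hot)\n",
    "# tensor([[0, 0, 1],\n",
    "#         [1, 0, 0],\n",
    "#         [0, 1, 0]])"
   ]
  },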
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h3 style=\"color: Salmon;\">Reorganization</h3>\n",
    "重组是指变换数组或向量的形状，以下内容以 torch API 作为示例。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 1, 2, 1])"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.randn((3, 1, 2, 1))\n",
    "inps.size()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">expand</span>\n",
    "\n",
    "此函数将维度值为 1 的维度扩展为制定维度，通常用于将两个待操作向量扩展到相同尺寸，类似的操作有 broadcast 和 repeat 等"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 5])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.randn((5))\n",
    "\n",
    "out_expand = inps.expand(2, -1)\n",
    "out_expand.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 5])"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.randn((1, 5))\n",
    "\n",
    "out_expand = inps.expand(2, -1)\n",
    "out_expand.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 3, 2, 5])"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.randn((3, 1, 2, 1))\n",
    "\n",
    "out_expand = inps.expand(-1, 3, 2, 5)\n",
    "out_expand.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4637944192, 4637944192, 4637944192)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps.untyped_storage().data_ptr(), out_expand.untyped_storage().data_ptr(), out_expand.data_ptr()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">repeat</span> \n",
    "\n",
    "repeat 拷贝数据，而 expand 是原始数据的一个视图"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 6])"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.arange(6).view(2, 3)\n",
    "\n",
    "out_repeat = inps.repeat((2, 2))\n",
    "out_repeat.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 9])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out_repeat = inps.repeat((1, 3))\n",
    "out_repeat.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5152648960, 5109241472)"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out_repeat.untyped_storage().data_ptr(), inps.untyped_storage().data_ptr()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">narrow</span>\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 0.6055, -1.2727, -1.1776, -0.3959],\n",
       "         [-0.1464, -2.1016, -0.1463, -1.7791],\n",
       "         [ 0.3520, -0.3829,  0.0152, -1.4906]],\n",
       "\n",
       "        [[-1.3540, -0.1396, -2.2507, -0.8397],\n",
       "         [-0.9213,  0.4148,  1.0474,  0.1667],\n",
       "         [-0.5506, -0.7384,  0.1772,  0.5518]]])"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.randn((2, 3, 4))\n",
    "inps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "narrow() received an invalid combination of arguments - got (length=int, strat=int, dim=int, ), but expected one of:\n * (int dim, Tensor start, int length)\n      didn't match because some of the keywords were incorrect: strat\n * (int dim, int start, int length)\n      didn't match because some of the keywords were incorrect: strat\n",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[35], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m out_narrow \u001b[38;5;241m=\u001b[39m \u001b[43minps\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnarrow\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrat\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlength\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m      2\u001b[0m out_narrow\u001b[38;5;241m.\u001b[39msize()\n",
      "\u001b[0;31mTypeError\u001b[0m: narrow() received an invalid combination of arguments - got (length=int, strat=int, dim=int, ), but expected one of:\n * (int dim, Tensor start, int length)\n      didn't match because some of the keywords were incorrect: strat\n * (int dim, int start, int length)\n      didn't match because some of the keywords were incorrect: strat\n"
     ]
    }
   ],
   "source": [
    "out_narrow = inps.narrow(dim=0, strat=0, length=1)\n",
    "out_narrow.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 1, 4])"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out_narrow = inps.narrow(1, 0, 1)\n",
    "out_narrow.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 2, 4])"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out_narrow = inps.narrow(1, 1, 2)\n",
    "out_narrow.size()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">squeeze & unsqueeze</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 3, 4])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.randn((2, 3, 4, 1))\n",
    "\n",
    "out_squeeze = inps.squeeze(-1)\n",
    "out_squeeze.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 2, 3, 4])"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.randn((2, 3, 4))\n",
    "\n",
    "out_squeeze = inps.unsqueeze(0)\n",
    "out_squeeze.size()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">unfold</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1., 2., 3., 4., 5., 6., 7.])"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.arange(1., 8)\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 2.],\n",
       "        [2., 3.],\n",
       "        [3., 4.],\n",
       "        [4., 5.],\n",
       "        [5., 6.],\n",
       "        [6., 7.]])"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.unfold(0, 2, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 2.],\n",
       "        [3., 4.],\n",
       "        [5., 6.]])"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.unfold(0, 2, 2)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">unbind</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 128, 128])"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.randn((4, 1, 128, 128))\n",
    "\n",
    "out_unbind = inps.unbind(0)\n",
    "out_unbind[0].size()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">chunk</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tuple"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.randn((3, 4, 2, 2))\n",
    "\n",
    "out_chunk = inps.chunk(chunks=2, dim=1)\n",
    "type(out_chunk)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(out_chunk)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 2, 2, 2])"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out_chunk[0].size()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">stack</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 1, 128, 128])"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps_1 = torch.randn((4, 1, 128, 128))\n",
    "inps_2 = torch.randn((4, 1, 128, 128))\n",
    "inps_3 = torch.randn((4, 1, 128, 128))\n",
    "\n",
    "y = torch.stack([inps_1, inps_2, inps_3], dim=1)\n",
    "y.size()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">cat</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 128, 128])"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps_1 = torch.randn((4, 1, 128, 128))\n",
    "inps_2 = torch.randn((4, 1, 128, 128))\n",
    "inps_3 = torch.randn((4, 1, 128, 128))\n",
    "\n",
    "y = torch.cat([inps_1, inps_2, inps_3], dim=1)\n",
    "y.size()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">view</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 16])"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.randn((4, 3, 4, 4))\n",
    "\n",
    "out_view = inps.view(4, 3, 4 * 4)\n",
    "out_view.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 4, 4])"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out_review = out_view.view(4, 3, 4, 4)\n",
    "out_review.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.equal(out_review, out_review)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">reshape</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 16])"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.randn((4, 3, 4, 4))\n",
    "\n",
    "out_view = inps.reshape(4, 3, 4 * 4)\n",
    "out_view.size()"
   ]
  },
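  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "view and reshape give the same result in the cells above, but they behave differently on non-contiguous tensors: view never copies and raises an error when the strides are incompatible, while reshape falls back to a copy. The cell below is a minimal sketch of that difference; `base`, `transposed`, and `flat` are illustrative names."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Illustrative tensor (assumed name, not part of the original notebook)\n",
    "base = torch.arange(6).view(2, 3)\n",
    "transposed = base.t()                # a non-contiguous view of base\n",
    "print(transposed.is_contiguous())    # False\n",
    "\n",
    "try:\n",
    "    transposed.view(6)               # view cannot reinterpret these strides\n",
    "except RuntimeError as err:\n",
    "    print('view failed:', type(err).__name__)\n",
    "\n",
    "flat = transposed.reshape(6)         # reshape copies when a view is impossible\n",
    "print(flat)                          # tensor([0, 3, 1, 4, 2, 5])\n",
    "print(flat.data_ptr() == base.data_ptr())             # False\n",
    "\n",
    "# On a contiguous tensor, reshape returns a view that shares the same memory\n",
    "print(base.reshape(6).data_ptr() == base.data_ptr())  # True"
   ]
  },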
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">permute</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 4, 3])"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.randn((2, 3, 4))\n",
    "\n",
    "out_permute = inps.permute([0, 2, 1])\n",
    "out_permute.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 4, 2])"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out_permute = inps.permute([1, 2, 0])\n",
    "out_permute.size()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">transpose / .T</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 2, 4])"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.randn((2, 3, 4))\n",
    "\n",
    "out_transpose = inps.transpose(0, 1)\n",
    "out_transpose.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 4, 3])"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out_transpose = inps.transpose(2, 1)\n",
    "out_transpose.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-56-aa4155aede8c>:1: UserWarning: The use of `x.T` on tensors of dimension other than 2 to reverse their shape is deprecated and it will throw an error in a future release. Consider `x.mT` to transpose batches of matrices or `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/TensorShape.cpp:3575.)\n",
      "  out_T = inps.T\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 3, 2])"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out_T = inps.T\n",
    "out_T.size()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">einsum</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from einops import rearrange\n",
    "from torch import einsum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inps = torch.tensor([\n",
    "    [[1, 2, 3], [4, 5, 6]],\n",
    "    [[-1, -2, -3], [-4, -5, -6]]\n",
    "], dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 1.,  4.],\n",
       "         [ 2.,  5.],\n",
       "         [ 3.,  6.]],\n",
       "\n",
       "        [[-1., -4.],\n",
       "         [-2., -5.],\n",
       "         [-3., -6.]]])"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 矩阵转置\n",
    "rearrange(inps, \"b c d -> b d c\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.,  2.,  3.,  4.,  5.,  6.],\n",
       "        [-1., -2., -3., -4., -5., -6.]])"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Merge dimensions: flatten c and d into a single axis\n",
    "rearrange(inps, 'b c d -> b (c d)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 1.,  2.,  3.],\n",
       "         [ 4.,  5.,  6.]],\n",
       "\n",
       "        [[-1., -2., -3.],\n",
       "         [-4., -5., -6.]]])"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Split dimensions: restore (c, d) from the merged axis, given c=2\n",
    "rearrange(rearrange(inps, 'b c d -> b (c d)'), 'b (c d) -> b c d', c=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[14., 32.],\n",
       "         [32., 77.]],\n",
       "\n",
       "        [[14., 32.],\n",
       "         [32., 77.]]])"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Batched matrix multiplication: each matrix times its own transpose\n",
    "einsum(\"b i d, b j d -> b i j\", inps, inps)"
   ]
  },
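  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A hand check of the batched product above (writing $x$ for `inps`, a symbol introduced only for this note). The pattern `b i d, b j d -> b i j` sums over the shared index `d`:\n",
    "\n",
    "$$\\text{out}_{b,i,j} = \\sum_{d} x_{b,i,d} \\, x_{b,j,d}$$\n",
    "\n",
    "For the first batch, with rows $(1, 2, 3)$ and $(4, 5, 6)$ and zero-based indices:\n",
    "\n",
    "$$\\text{out}_{0,0,0} = 1^2 + 2^2 + 3^2 = 14, \\quad \\text{out}_{0,0,1} = 1 \\cdot 4 + 2 \\cdot 5 + 3 \\cdot 6 = 32, \\quad \\text{out}_{0,1,1} = 4^2 + 5^2 + 6^2 = 77$$\n",
    "\n",
    "The second batch is the first with every sign flipped, so its products are identical."
   ]
  },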
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[14., 32.],\n",
       "         [32., 77.]],\n",
       "\n",
       "        [[14., 32.],\n",
       "         [32., 77.]]])"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same product, with the second operand transposed explicitly\n",
    "einsum(\"b i d, b d j -> b i j\", inps, rearrange(inps, 'b c d -> b d c'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 1.,  4.,  9.],\n",
       "         [16., 25., 36.]],\n",
       "\n",
       "        [[ 1.,  4.,  9.],\n",
       "         [16., 25., 36.]]])"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Element-wise multiplication\n",
    "einsum(\"b i d, b i d -> b i d\", inps, inps)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h3 style=\"color: Salmon;\">Numerical</h3>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inps = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">diagonal</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1, 5, 9])"
      ]
     },
     "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps.diagonal()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2, 6])"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps.diagonal(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([4, 8])"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps.diagonal(-1)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">ceil</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2., 3., 4.])"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.tensor([1.1, 2.5, 3.7])\n",
    "\n",
    "inps.ceil()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">floor</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1., 2., 3.])"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.tensor([1.1, 2.5, 3.7])\n",
    "\n",
    "inps.floor()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">clip / clamp</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2.0000, 2.5000, 3.0000])"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.tensor([1.1, 2.5, 3.7])\n",
    "\n",
    "inps.clamp(min=2, max=3)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">lerp</span>\n",
    "\n",
    "$\\text{out}_{i} = \\text{start}_{i} + \\text{weight}_{i} \\times (\\text{end}_{i} - \\text{start}_{i})$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1., 2., 3., 4.])"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "start = torch.arange(1., 5.)\n",
    "start"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([10., 10., 10., 10.])"
      ]
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "end = torch.empty(4).fill_(10)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([5.5000, 6.0000, 6.5000, 7.0000])"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "start = torch.arange(1., 5.)\n",
    "end = torch.empty(4).fill_(10)\n",
    "torch.lerp(input=start, end=end, weight=0.5)"
   ]
  },
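  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A hand check of the result above, with `start = [1, 2, 3, 4]`, `end = [10, 10, 10, 10]`, and the scalar `weight = 0.5` applied to every element. Each output is the midpoint of its pair:\n",
    "\n",
    "$$1 + 0.5 \\times (10 - 1) = 5.5, \\quad 2 + 0.5 \\times (10 - 2) = 6.0, \\quad 3 + 0.5 \\times (10 - 3) = 6.5, \\quad 4 + 0.5 \\times (10 - 4) = 7.0$$"
   ]
  },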
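  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">mm / bmm / matmul</span>\n",
    "\n",
    "`mm` only multiplies two 2-D matrices and does not broadcast.\n",
    "\n",
    "`bmm` performs batched matrix multiplication on two 3-D tensors with the same batch size.\n",
    "\n",
    "`matmul` covers vector, matrix, and batched products, broadcasting over the batch dimensions. It does not perform element-wise multiplication; use `*` for that."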
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 2.6112, -2.2565],\n",
       "        [-2.2565,  3.5055]])"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.randn((2, 3)).to(torch.float32)\n",
    "\n",
    "torch.mm(inps, inps.t())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "mat1 and mat2 shapes cannot be multiplied (2x3 and 2x3)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[74], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m torch\u001b[39m.\u001b[39;49mmm(inps, inps)\n",
      "\u001b[0;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (2x3 and 2x3)"
     ]
    }
   ],
   "source": [
    "# Raises RuntimeError: the inner dimensions of (2, 3) @ (2, 3) do not match\n",
    "torch.mm(inps, inps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 2.2204, -5.1686],\n",
       "         [-5.1686, 14.5050]]])"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = torch.randn((2, 3)).to(torch.float32)\n",
    "a = a.unsqueeze(0)\n",
    "\n",
    "torch.bmm(a, a.transpose(1, 2))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The behavior of `torch.matmul` depends on the dimensionality of its arguments:\n",
    "- If both tensors are 1-dimensional, the dot product (scalar) is returned.\n",
    "- If both arguments are 2-dimensional, the matrix-matrix product is returned.\n",
    "- If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.\n",
    "- If the first argument is 2-dimensional and the second argument is 1-dimensional, the matrix-vector product is returned.\n",
    "- If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after. If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix multiply and removed after. The non-matrix (i.e. batch) dimensions are broadcast (and thus must be broadcastable). For example, if input is a (j × 1 × n × n) tensor and other is a (k × n × n) tensor, out will be a (j × k × n × n) tensor.\n",
    "\n",
    "Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs are broadcastable, and not the matrix dimensions. For example, if input is a (j × 1 × n × m) tensor and other is a (k × m × p) tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the matrix dimensions) are different. out will be a (j × k × n × p) tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(5), torch.Size([]))"
      ]
     },
     "execution_count": 80,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# vector * vector\n",
    "tensor1 = torch.arange(3)\n",
    "tensor2 = torch.arange(3)\n",
    "\n",
    "torch.matmul(tensor1, tensor2), torch.matmul(tensor1, tensor2).size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([5]), torch.Size([1]))"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# matrix * vector\n",
    "tensor1 = torch.arange(3).unsqueeze(0)\n",
    "tensor2 = torch.arange(3)\n",
    "\n",
    "torch.matmul(tensor1, tensor2), torch.matmul(tensor1, tensor2).size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 5, 14]]), torch.Size([1, 2]))"
      ]
     },
     "execution_count": 84,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# batched matrix (1, 2, 3) * vector (3,): the vector is applied to every matrix in the batch\n",
    "tensor1 = torch.arange(6).view(2, 3).unsqueeze(0)\n",
    "tensor2 = torch.arange(3)\n",
    "\n",
    "torch.matmul(tensor1, tensor2), torch.matmul(tensor1, tensor2).size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[[10, 13],\n",
       "          [28, 40]]]),\n",
       " torch.Size([1, 2, 2]))"
      ]
     },
     "execution_count": 85,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# batched matrix (1, 2, 3) * batched matrix (1, 3, 2)\n",
    "tensor1 = torch.arange(6).view(2, 3).unsqueeze(0)\n",
    "tensor2 = torch.arange(6).view(3, 2).unsqueeze(0)\n",
    "\n",
    "torch.matmul(tensor1, tensor2), torch.matmul(tensor1, tensor2).size()"
   ]
  },
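  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A hand check of the batched product above. The single batch multiplies a $2 \\times 3$ matrix by a $3 \\times 2$ matrix, and each output entry is a row-by-column dot product:\n",
    "\n",
    "$$\\begin{pmatrix} 0 & 1 & 2 \\\\ 3 & 4 & 5 \\end{pmatrix} \\begin{pmatrix} 0 & 1 \\\\ 2 & 3 \\\\ 4 & 5 \\end{pmatrix} = \\begin{pmatrix} 10 & 13 \\\\ 28 & 40 \\end{pmatrix}$$\n",
    "\n",
    "For example, the top-left entry is $0 \\cdot 0 + 1 \\cdot 2 + 2 \\cdot 4 = 10$ and the bottom-right entry is $3 \\cdot 1 + 4 \\cdot 3 + 5 \\cdot 5 = 40$."
   ]
  },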
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[[10, 13],\n",
       "          [28, 40]]]),\n",
       " torch.Size([1, 2, 2]))"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# batched matrix (1, 2, 3) * plain matrix (3, 2): the matrix is broadcast across the batch\n",
    "tensor1 = torch.arange(6).view(2, 3).unsqueeze(0)\n",
    "tensor2 = torch.arange(6).view(3, 2)\n",
    "\n",
    "torch.matmul(tensor1, tensor2), torch.matmul(tensor1, tensor2).size()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">* (element-wise multiplication)</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tensor1 = torch.arange(6).view(2, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0,  2,  4],\n",
       "        [ 6,  8, 10]])"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tensor1 * 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0,  1,  4],\n",
       "        [ 9, 16, 25]])"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tensor1 * tensor1"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">F.normalize</span>\n",
    "\n",
    "$v = \\frac{v}{\\max(\\lVert v \\rVert_{p}, \\epsilon)}$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 1., 2., 3., 4.]])"
      ]
     },
     "execution_count": 163,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.arange(5).unsqueeze(0).to(torch.float32)\n",
    "\n",
    "inps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0000, 0.1826, 0.3651, 0.5477, 0.7303]])"
      ]
     },
     "execution_count": 164,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "F.normalize(inps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0000, 0.1000, 0.2000, 0.3000, 0.4000]])"
      ]
     },
     "execution_count": 165,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "F.normalize(inps, p=1)"
   ]
  },
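  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A hand check of the two results above, assuming the defaults `p=2`, `dim=1`, and `eps=1e-12`. The norm is then taken along the single row $v = (0, 1, 2, 3, 4)$, and $\\epsilon$ has no effect because the norm is far from zero:\n",
    "\n",
    "$$\\lVert v \\rVert_{2} = \\sqrt{0^2 + 1^2 + 2^2 + 3^2 + 4^2} = \\sqrt{30} \\approx 5.4772, \\quad \\frac{v}{\\lVert v \\rVert_{2}} \\approx (0, 0.1826, 0.3651, 0.5477, 0.7303)$$\n",
    "\n",
    "$$\\lVert v \\rVert_{1} = 0 + 1 + 2 + 3 + 4 = 10, \\quad \\frac{v}{\\lVert v \\rVert_{1}} = (0, 0.1, 0.2, 0.3, 0.4)$$"
   ]
  },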
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">non-zero normalize</span>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">softmax</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inps = torch.tensor([[[1, 1, 1], [2, 2, 2]], \n",
    "                     [[3, 3, 3], [4, 4, 4]]], dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.2689, 0.2689, 0.2689],\n",
       "         [0.7311, 0.7311, 0.7311]],\n",
       "\n",
       "        [[0.2689, 0.2689, 0.2689],\n",
       "         [0.7311, 0.7311, 0.7311]]])"
      ]
     },
     "execution_count": 265,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "softmax_1 = nn.Softmax(1)\n",
    "out_1 = softmax_1(inps)\n",
    "out_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.3333, 0.3333, 0.3333],\n",
       "         [0.3333, 0.3333, 0.3333]],\n",
       "\n",
       "        [[0.3333, 0.3333, 0.3333],\n",
       "         [0.3333, 0.3333, 0.3333]]])"
      ]
     },
     "execution_count": 266,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "softmax_2 = nn.Softmax(2)\n",
    "out_2 = softmax_2(inps)\n",
    "out_2"
   ]
  },
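  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A hand check of the two outputs above. Softmax exponentiates the entries along the chosen dimension and divides by their sum along that dimension:\n",
    "\n",
    "$$\\text{softmax}(x)_{i} = \\frac{e^{x_{i}}}{\\sum_{j} e^{x_{j}}}$$\n",
    "\n",
    "Along `dim=1` each group of entries is $(1, 2)$ or $(3, 4)$. Softmax depends only on the differences between entries, so both groups give the same values:\n",
    "\n",
    "$$\\frac{e^{1}}{e^{1} + e^{2}} = \\frac{1}{1 + e} \\approx 0.2689, \\quad \\frac{e^{2}}{e^{1} + e^{2}} = \\frac{e}{1 + e} \\approx 0.7311$$\n",
    "\n",
    "Along `dim=2` the three entries in each row are equal, so each one gets $1/3 \\approx 0.3333$."
   ]
  },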
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">sigmoid</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.7311, 0.7311, 0.7311],\n",
       "         [0.8808, 0.8808, 0.8808]],\n",
       "\n",
       "        [[0.9526, 0.9526, 0.9526],\n",
       "         [0.9820, 0.9820, 0.9820]]])"
      ]
     },
     "execution_count": 267,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sigmoid = nn.Sigmoid()\n",
    "out_sig = sigmoid(inps)\n",
    "out_sig"
   ]
  },
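  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unlike softmax, sigmoid is applied to each element independently, so the outputs along a dimension need not sum to 1. A hand check of the output above:\n",
    "\n",
    "$$\\sigma(x) = \\frac{1}{1 + e^{-x}}$$\n",
    "\n",
    "$$\\sigma(1) \\approx 0.7311, \\quad \\sigma(2) \\approx 0.8808, \\quad \\sigma(3) \\approx 0.9526, \\quad \\sigma(4) \\approx 0.9820$$"
   ]
  },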
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">masked_fill_</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 0, 0, 3, 3],\n",
       "        [3, 3, 2, 3, 3]])"
      ]
     },
     "execution_count": 268,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.tensor([[0, 0, 0, 0, 0], [2, 2, 2, 2, 2]])\n",
    "mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool)\n",
    "\n",
    "x.masked_fill_(mask, 3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">masked_select</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.0918, -0.1143, -2.2332,  0.2482],\n",
       "        [ 0.3171,  2.8210,  0.3838,  3.0182],\n",
       "        [ 0.4235,  0.3183, -0.5406,  1.1038]])"
      ]
     },
     "execution_count": 269,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.randn(3, 4)\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[False, False, False, False],\n",
       "        [False,  True, False,  True],\n",
       "        [False, False, False,  True]])"
      ]
     },
     "execution_count": 270,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mask = x.ge(0.5)\n",
    "mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2.8210, 3.0182, 1.1038])"
      ]
     },
     "execution_count": 271,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.masked_select(x, mask)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">masked_scatter</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 0, 0, 0, 1],\n",
       "        [2, 3, 0, 4, 5]])"
      ]
     },
     "execution_count": 272,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.tensor([0, 0, 0, 0, 0])\n",
    "mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]], dtype=torch.bool)\n",
    "source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])\n",
    "x.masked_scatter(mask, source)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h3 style=\"color: Salmon;\">List</h3>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
      ]
     },
     "execution_count": 273,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = [0] * 10\n",
    "inps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]"
      ]
     },
     "execution_count": 274,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = list(range(10))\n",
    "inps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]"
      ]
     },
     "execution_count": 275,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = inps * 2\n",
    "inps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[2, 4, 6], [8, 10, 12]]"
      ]
     },
     "execution_count": 276,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = [[1, 2, 3], [4, 5, 6]]\n",
    "out = [[x * 2 for x in sublist] for sublist in inps]\n",
    "out"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">reverse</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[9, 8, 7, 6, 5, 4, 3, 2, 1, 0]"
      ]
     },
     "execution_count": 277,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = list(range(10))\n",
    "inps.reverse()\n",
    "inps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9\n",
      "8\n",
      "7\n",
      "6\n",
      "5\n",
      "4\n",
      "3\n",
      "2\n",
      "1\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "inps = list(range(10))\n",
    "out = reversed(inps)\n",
    "for i in out:\n",
    "    print(i)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">append</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11]"
      ]
     },
     "execution_count": 279,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps.append(11)\n",
    "inps"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">remove</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0, 1, 2, 4, 5, 6, 7, 8, 9, 11]"
      ]
     },
     "execution_count": 280,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps.remove(3)\n",
    "inps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0, 1, 2, 4, 5, 6, 7, 8, 9]"
      ]
     },
     "execution_count": 281,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps.remove(11)\n",
    "inps"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">set</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{1, 2, 3}"
      ]
     },
     "execution_count": 282,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = [1, 1, 1, 2, 2, 2, 3, 3, 3]\n",
    "set(inps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "set"
      ]
     },
     "execution_count": 283,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(set(inps))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">*list (unpacking)</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1, 1, 1, 2, 2, 2, 3, 3, 3]"
      ]
     },
     "execution_count": 284,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 1 1 2 2 2 3 3 3\n"
     ]
    }
   ],
   "source": [
    "print(*inps)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h3 style=\"color: Salmon;\">Statistical</h3>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">std</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.707825127659933"
      ]
     },
     "execution_count": 286,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = np.array([[1, 2, 3], [4, 5, 6]])\n",
    "inps.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.81649658, 0.81649658])"
      ]
     },
     "execution_count": 287,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps.std(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1.5, 1.5, 1.5])"
      ]
     },
     "execution_count": 288,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps.std(0)"
   ]
  },
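  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A hand check of the three results above. `ndarray.std` defaults to `ddof=0`, the population standard deviation, which divides by $N$. By contrast, `torch.std` applies Bessel's correction by default and divides by $N - 1$, so the two libraries give different values for the same data unless the correction is set explicitly.\n",
    "\n",
    "$$\\sigma = \\sqrt{\\frac{1}{N} \\sum_{i=1}^{N} (x_{i} - \\bar{x})^{2}}$$\n",
    "\n",
    "Over all six values, $\\bar{x} = 3.5$:\n",
    "\n",
    "$$\\sigma = \\sqrt{\\frac{2 \\, (2.5^{2} + 1.5^{2} + 0.5^{2})}{6}} = \\sqrt{\\frac{17.5}{6}} \\approx 1.7078$$\n",
    "\n",
    "Along axis 1, each row is three consecutive integers, so $\\sigma = \\sqrt{2/3} \\approx 0.8165$. Along axis 0, each column holds two values that differ by 3, so $\\sigma = 1.5$."
   ]
  },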
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">max</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6"
      ]
     },
     "execution_count": 289,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([4, 5, 6])"
      ]
     },
     "execution_count": 290,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps.max(0)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">min</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 291,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps.min()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">unique</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 2, 3, 4, 5, 6])"
      ]
     },
     "execution_count": 292,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ele = np.unique(inps)\n",
    "ele"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">argmax / argmin</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5"
      ]
     },
     "execution_count": 293,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps.argmax()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 1, 1])"
      ]
     },
     "execution_count": 294,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps.argmax(0)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h3 style=\"color: Salmon;\">torch.nn.Module</h3>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">ModuleList</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ModuleList(\n",
       "  (0): Linear(in_features=10, out_features=10, bias=True)\n",
       "  (1): ReLU()\n",
       "  (2): LayerNorm((10,), eps=1e-05, elementwise_affine=True)\n",
       "  (3): Linear(in_features=10, out_features=5, bias=True)\n",
       "  (4): ReLU()\n",
       "  (5): LayerNorm((5,), eps=1e-05, elementwise_affine=True)\n",
       ")"
      ]
     },
     "execution_count": 295,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_list = nn.ModuleList()\n",
    "\n",
    "model_list.append(nn.Linear(10, 10))\n",
    "model_list.append(nn.ReLU())\n",
    "model_list.append(nn.LayerNorm(10))\n",
    "model_list.append(nn.Linear(10, 5))\n",
    "model_list.append(nn.ReLU())\n",
    "model_list.append(nn.LayerNorm(5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 5, 5])"
      ]
     },
     "execution_count": 296,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.randn((4, 5, 10), dtype=torch.float32)\n",
    "\n",
    "for blk in model_list:\n",
    "    inps = blk(inps)\n",
    "inps.size()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">Sequential</span>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = nn.Sequential(\n",
    "    nn.Linear(10, 10),\n",
    "    nn.ReLU(),\n",
    "    nn.LayerNorm(10),\n",
    "    nn.Linear(10, 5),\n",
    "    nn.ReLU(),\n",
    "    nn.LayerNorm(5),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 5, 5])"
      ]
     },
     "execution_count": 298,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.randn((4, 5, 10), dtype=torch.float32)\n",
    "\n",
    "out = model(inps)\n",
    "out.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_ = nn.Sequential(*model_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 5, 5])"
      ]
     },
     "execution_count": 300,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inps = torch.randn((4, 5, 10), dtype=torch.float32)\n",
    "\n",
    "out = model_(inps)\n",
    "out.size()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<span style=\"color: DodgerBlue;\">RegisterBuffer</span>\n",
    "\n",
    "用于保存不使用梯度学习的参数，但此参数可以使用 EMA 等方法进行更新"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 1.]])"
      ]
     },
     "execution_count": 301,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class MyModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.register_buffer(\"momentum\", torch.ones(1, 2))\n",
    "\n",
    "model = MyModel()\n",
    "model.momentum"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
